#  This file is part of Pynguin.
#
#  SPDX-FileCopyrightText: 2019–2021 Pynguin Contributors
#
#  SPDX-License-Identifier: LGPL-3.0-or-later
#
"""An abstract test exporter"""
import ast
import os
from abc import ABCMeta, abstractmethod
from pathlib import Path
from typing import List, Optional, Set, Tuple, Union

import astor

import pynguin.testcase.testcase as tc
import pynguin.testcase.testcase_to_ast as tc_to_ast
from pynguin.utils.namingscope import NamingScope


# pylint: disable=too-few-public-methods
class AbstractTestExporter(metaclass=ABCMeta):
    """An abstract test exporter.

    Bundles the helpers shared by concrete exporters: turning test cases into
    AST statements, building import and function nodes, and writing the
    resulting module to a file.  Subclasses implement ``export_sequences``.
    """

    def __init__(self, wrap_code: bool = False) -> None:
        """Initializes the exporter.

        Args:
            wrap_code: Forwarded as ``wrap_code`` to the
                ``TestCaseToAstVisitor`` created in ``_transform_to_asts``.
        """
        self._wrap_code = wrap_code

    @abstractmethod
    def export_sequences(
        self, path: Union[str, os.PathLike], test_cases: List[tc.TestCase]
    ):
        """Exports test cases to an AST module, where each test case is a method.

        Args:
            path: Destination file for the exported test case.
            test_cases: A list of test cases.

        Returns:  # noqa: DAR202
            An AST module that contains the methods for these test cases.
        """

    def _transform_to_asts(
        self,
        test_cases: List[tc.TestCase],
    ) -> Tuple[NamingScope, Set[str], List[List[ast.stmt]]]:
        """Runs one AST visitor over all given test cases.

        Args:
            test_cases: The test cases to convert.

        Returns:
            The visitor's ``module_aliases``, ``common_modules`` and
            ``test_case_asts`` after it has visited every test case.
        """
        visitor = tc_to_ast.TestCaseToAstVisitor(wrap_code=self._wrap_code)
        for test_case in test_cases:
            test_case.accept(visitor)
        return visitor.module_aliases, visitor.common_modules, visitor.test_case_asts

    @staticmethod
    def _create_ast_imports(
        module_aliases: NamingScope, common_modules: Optional[Set[str]] = None
    ) -> List[ast.stmt]:
        """Builds the ``import`` statements for an exported module.

        Args:
            module_aliases: Naming scope whose known names are imported as
                ``import <name> as <alias>``.
            common_modules: Optional module names imported without an alias.

        Returns:
            The plain imports first, followed by the aliased imports.
        """
        imports: List[ast.stmt] = []
        if common_modules is not None:
            # A set has no stable iteration order; sorting keeps the generated
            # file identical from run to run.
            imports.extend(
                ast.Import(names=[ast.alias(name=module, asname=None)])
                for module in sorted(common_modules)
            )
        imports.extend(
            ast.Import(
                names=[
                    ast.alias(
                        name=module_name,
                        asname=module_aliases.get_name(module_name),
                    )
                ]
            )
            for module_name in module_aliases.known_name_indices
        )
        return imports

    @staticmethod
    def _create_functions(
        asts: List[List[ast.stmt]], with_self_arg: bool
    ) -> List[ast.stmt]:
        """Wraps each statement list into a ``test_case_<i>`` function.

        Args:
            asts: One list of statements per test case.
            with_self_arg: Whether the functions take a ``self`` parameter.

        Returns:
            One function definition per entry of ``asts``, in the same order.
        """
        return [
            AbstractTestExporter.__create_function_node(
                f"case_{index}",
                # A function body must not be empty, so fall back to ``pass``.
                statements or [ast.Pass()],
                with_self_arg,
            )
            for index, statements in enumerate(asts)
        ]

    @staticmethod
    def __create_function_node(
        function_name: str, nodes: List[ast.stmt], with_self_arg: bool
    ) -> ast.FunctionDef:
        """Creates the definition of a single test function.

        Args:
            function_name: Name of the function without the ``test_`` prefix.
            nodes: The statements forming the function body.
            with_self_arg: Whether the function takes a ``self`` parameter.

        Returns:
            The function definition node named ``test_<function_name>``.
        """
        # Parameters have to be ``ast.arg`` nodes; an ``ast.Name`` is only
        # tolerated by lenient code generators and is rejected by ``compile``.
        parameters = [ast.arg(arg="self", annotation=None)] if with_self_arg else []
        return ast.FunctionDef(
            name=f"test_{function_name}",
            args=ast.arguments(
                posonlyargs=[],
                args=parameters,
                defaults=[],
                vararg=None,
                kwarg=None,
                kwonlyargs=[],
                kw_defaults=[],
            ),
            body=nodes,
            decorator_list=[],
            returns=None,
        )

    @staticmethod
    def _save_ast_to_file(path: Union[str, os.PathLike], module: ast.Module) -> None:
        """Saves an AST module to a file.

        Missing parent directories are created.  The file is written as UTF-8,
        the default source encoding of Python, so that non-ASCII literals in
        the generated code do not depend on the platform's locale.

        Args:
            path: Destination file
            module: The AST module
        """
        target = Path(path)
        target.parent.mkdir(parents=True, exist_ok=True)
        with target.open(mode="w", encoding="utf-8") as file:
            file.write("# Automatically generated by Pynguin.\n")
            file.write(astor.to_source(module))
